{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "fFjof1NgAJwu"
      },
      "outputs": [],
      "source": [
        "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
        "\n",
        "# Licensed to the Apache Software Foundation (ASF) under one\n",
        "# or more contributor license agreements. See the NOTICE file\n",
        "# distributed with this work for additional information\n",
        "# regarding copyright ownership. The ASF licenses this file\n",
        "# to you under the Apache License, Version 2.0 (the\n",
        "# \"License\"); you may not use this file except in compliance\n",
        "# with the License. You may obtain a copy of the License at\n",
        "#\n",
        "#   http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing,\n",
        "# software distributed under the License is distributed on an\n",
        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
        "# KIND, either express or implied. See the License for the\n",
        "# specific language governing permissions and limitations\n",
        "# under the License"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A8xNRyZMW1yK"
      },
      "source": [
        "# Use RunInference with TFX Basic Shared Libraries\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_tensorflow_with_tfx.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_tensorflow_with_tfx.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HrCtxslBGK8Z"
      },
      "source": [
        "This notebook demonstrates how to use the Apache Beam [RunInference](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.RunInference) transform with TensorFlow and [TFX Basic Shared Libraries](https://github.com/tensorflow/tfx-bsl) (`tfx-bsl`).\n",
        "\n",
        "Use this approach when your trained model uses a `tf.Example` input.\n",
        "If you have `numpy` or `tf.Tensor` inputs, see the [Apache Beam RunInference with TensorFlow](https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_tensorflow.ipynb) notebook, which shows how to use the built-in TensorFlow model handlers.\n",
        "\n",
        "The Apache Beam RunInference transform accepts a model handler generated from [`tfx-bsl`](https://github.com/tensorflow/tfx-bsl) by using `CreateModelHandler`.\n",
        "\n",
        "This notebook demonstrates how to complete the following tasks:\n",
        "- Import `tfx-bsl`.\n",
        "- Build a simple TensorFlow model.\n",
        "- Set up example data.\n",
        "- Use the `tfx-bsl` model handler with the example data, and get a prediction inside an Apache Beam pipeline.\n",
        "\n",
        "For more information about using RunInference, see [Get started with AI/ML pipelines](https://beam.apache.org/documentation/ml/overview/) in the Apache Beam documentation."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HrCtxslBGK8A"
      },
      "source": [
        "## Before you begin\n",
        "Set up your environment and download dependencies."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HrCtxslBGK8A"
      },
      "source": [
        "### Import `tfx-bsl`\n",
        "First, import `tfx-bsl`.\n",
        "Creating a model handler is supported in `tfx-bsl` versions 1.10 and later."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jBakpNZnAhqk"
      },
      "outputs": [],
      "source": [
        "!pip install tfx_bsl==1.10.0 --quiet\n",
        "!pip install protobuf --quiet\n",
        "!pip install apache_beam[interactive] --quiet"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X80jy3FqHjK4"
      },
      "source": [
        "### Authenticate with Google Cloud\n",
        "This notebook relies on saving your model to Google Cloud. To use your Google Cloud account, authenticate this notebook."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "Kz9sccyGBqz3"
      },
      "outputs": [],
      "source": [
        "from google.colab import auth\n",
        "auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "40qtP6zJuMXm"
      },
      "source": [
        "### Import dependencies and set up your bucket\n",
        "Use the following code to import dependencies and to set up your Google Cloud Storage bucket.\n",
        "\n",
        "Replace `PROJECT_ID` and `BUCKET_NAME` with the ID of your project and the name of your bucket.\n",
        "\n",
        "**Important**: If an error occurs, restart your runtime."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "eEle839_Akqx"
      },
      "outputs": [],
      "source": [
        "import argparse\n",
        "\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "from tensorflow_serving.apis import prediction_log_pb2\n",
        "\n",
        "import apache_beam as beam\n",
        "from apache_beam.ml.inference.base import RunInference\n",
        "import tfx_bsl\n",
        "from tfx_bsl.public.beam.run_inference import CreateModelHandler\n",
        "from tfx_bsl.public import tfxio\n",
        "from tfx_bsl.public.proto import model_spec_pb2\n",
        "from tensorflow_metadata.proto.v0 import schema_pb2\n",
        "\n",
        "import numpy\n",
        "\n",
        "from typing import Dict, Text, Any, Tuple, List\n",
        "\n",
        "from apache_beam.options.pipeline_options import PipelineOptions\n",
        "\n",
        "project = \"PROJECT_ID\" # @param {type:'string'}\n",
        "bucket = \"BUCKET_NAME\" # @param {type:'string'}\n",
        "\n",
        "save_model_dir_multiply = f'gs://{bucket}/tfx-inference/model/multiply_five/v1/'\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YzvZWEv-1oiK"
      },
      "source": [
        "## Create and test a simple model\n",
        "\n",
        "This section creates and tests a model that predicts the 5 times multiplication table."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rIwD_qEpX7Gu"
      },
      "source": [
        "### Create the model\n",
        "Create training data, and then build a linear regression model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "SH7iq3zeBBJ-",
        "outputId": "c5adb7ec-285b-401e-f9be-1e9b83c6d0ba"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Model: \"model\"\n",
            "_________________________________________________________________\n",
            " Layer (type)                Output Shape              Param #   \n",
            "=================================================================\n",
            " x (InputLayer)              [(None, 1)]               0         \n",
            "                                                                 \n",
            " dense (Dense)               (None, 1)                 2         \n",
            "                                                                 \n",
            "=================================================================\n",
            "Total params: 2\n",
            "Trainable params: 2\n",
            "Non-trainable params: 0\n",
            "_________________________________________________________________\n"
          ]
        }
      ],
      "source": [
        "# Create training data that represents the 5 times multiplication table for the numbers 0 to 99.\n",
        "# x is the data and y is the labels.\n",
        "x = numpy.arange(0, 100)   # Examples\n",
        "y = x * 5                  # Labels\n",
        "\n",
        "# Build a simple linear regression model.\n",
        "# Note that the model has a shape of (1) for its input layer and expects a single int64 value.\n",
        "input_layer = keras.layers.Input(shape=(1), dtype=tf.float32, name='x')\n",
        "output_layer= keras.layers.Dense(1)(input_layer)\n",
        "\n",
        "model = keras.Model(input_layer, output_layer)\n",
        "model.compile(optimizer=tf.optimizers.Adam(), loss='mean_absolute_error')\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O_a0-4Gb19cy"
      },
      "source": [
        "### Test the model\n",
        "\n",
        "This step tests the model that you created."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5XkIYXhJBFmS",
        "outputId": "e3bb5079-5cb8-4fe4-eb8d-d3d13d5f9f0c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1/1 [==============================] - 0s 94ms/step\n",
            "Test Examples [20, 40, 60, 90]\n",
            "Predictions [[ 9.201942]\n",
            " [16.40566 ]\n",
            " [23.609379]\n",
            " [34.41496 ]]\n"
          ]
        }
      ],
      "source": [
        "model.fit(x, y, epochs=500, verbose=0)\n",
        "test_examples =[20, 40, 60, 90]\n",
        "value_to_predict = numpy.array(test_examples, dtype=numpy.float32)\n",
        "predictions = model.predict(value_to_predict)\n",
        "\n",
        "print('Test Examples ' + str(test_examples))\n",
        "print('Predictions ' + str(predictions))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dEmleqiH3t71"
      },
      "source": [
        "## RunInference with Tensorflow using `tfx-bsl`\n",
        "In versions 1.10.0 and later of `tfx-bsl`, you can\n",
        "create a TensorFlow `ModelHandler` to use with Apache Beam.\n",
        "\n",
        "### Populate the data in a TensorFlow proto\n",
        "\n",
        "Tensorflow data uses protos. If you are loading from a file, helpers exist for this step. Because this example uses generated data, this code populates a proto."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "XvKc9kQilPjx"
      },
      "outputs": [],
      "source": [
        "# This example shows a proto that converts the samples and labels into\n",
        "# tensors usable by TensorFlow.\n",
        "\n",
        "class ExampleProcessor:\n",
        "    def create_example_with_label(self, feature: numpy.float32,\n",
        "                             label: numpy.float32)-> tf.train.Example:\n",
        "        return tf.train.Example(\n",
        "            features=tf.train.Features(\n",
        "                  feature={'x': self.create_feature(feature),\n",
        "                           'y' : self.create_feature(label)\n",
        "                  }))\n",
        "\n",
        "    def create_example(self, feature: numpy.float32):\n",
        "        return tf.train.Example(\n",
        "            features=tf.train.Features(\n",
        "                  feature={'x' : self.create_feature(feature)})\n",
        "            )\n",
        "\n",
        "    def create_feature(self, element: numpy.float32):\n",
        "        return tf.train.Feature(float_list=tf.train.FloatList(value=[element]))\n",
        "\n",
        "# Create a labeled example file for the 5 times table.\n",
        "\n",
        "example_five_times_table = 'example_five_times_table.tfrecord'\n",
        "\n",
        "with tf.io.TFRecordWriter(example_five_times_table) as writer:\n",
        "  for i in zip(x, y):\n",
        "    example = ExampleProcessor().create_example_with_label(\n",
        "        feature=i[0], label=i[1])\n",
        "    writer.write(example.SerializeToString())\n",
        "\n",
        "# Create a file containing the values to predict.\n",
        "\n",
        "predict_values_five_times_table = 'predict_values_five_times_table.tfrecord'\n",
        "\n",
        "with tf.io.TFRecordWriter(predict_values_five_times_table) as writer:\n",
        "  for i in value_to_predict:\n",
        "    example = ExampleProcessor().create_example(feature=i)\n",
        "    writer.write(example.SerializeToString())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G-sAu3cf31f3"
      },
      "source": [
        "### Fit the model\n",
        "\n",
        "This step builds a model. Because RunInference requires pretrained models, this segment builds a usable model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "AnbrxXPKeAOQ",
        "outputId": "42439aac-3a10-4e86-829f-44332aad6173"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<keras.callbacks.History at 0x7f6960074c70>"
            ]
          },
          "execution_count": 8,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "RAW_DATA_TRAIN_SPEC = {\n",
        "'x': tf.io.FixedLenFeature([], tf.float32),\n",
        "'y': tf.io.FixedLenFeature([], tf.float32)\n",
        "}\n",
        "\n",
        "dataset = tf.data.TFRecordDataset(example_five_times_table)\n",
        "dataset = dataset.map(lambda e : tf.io.parse_example(e, RAW_DATA_TRAIN_SPEC))\n",
        "dataset = dataset.map(lambda t : (t['x'], t['y']))\n",
        "dataset = dataset.batch(100)\n",
        "dataset = dataset.repeat()\n",
        "\n",
        "model.fit(dataset, epochs=5000, steps_per_epoch=1, verbose=0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r4dpR6dQ4JwX"
      },
      "source": [
        "### Save the model\n",
        "\n",
        "This step shows how to save your model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "fYvrIYO3qiJx"
      },
      "outputs": [],
      "source": [
        "RAW_DATA_PREDICT_SPEC = {\n",
        "'x': tf.io.FixedLenFeature([], tf.float32),\n",
        "}\n",
        "\n",
        "# tf.function compiles the function into a callable TensorFlow graph.\n",
        "# RunInference relies on calling a TensorFlow graph as a model.\n",
        "# Note: Use the input signature type tf.string, because it's supported by\n",
        "# tfx-bsl ModelHandlers.\n",
        "@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string , name='examples')])\n",
        "def serve_tf_examples_fn(serialized_tf_examples):\n",
        "  \"\"\"Returns the output to be used in the serving signature.\"\"\"\n",
        "  features = tf.io.parse_example(serialized_tf_examples, RAW_DATA_PREDICT_SPEC)\n",
        "  return model(features, training=False)\n",
        "\n",
        "signature = {'serving_default': serve_tf_examples_fn}\n",
        "\n",
        "# Signatures define the input and output types for a computation. The optional\n",
        "# save signatures argument controls which methods in obj are available to\n",
        "# programs that consume SavedModels, such as serving APIs.\n",
        "# See https://www.tensorflow.org/api_docs/python/tf/saved_model/save\n",
        "tf.keras.models.save_model(model, save_model_dir_multiply, signatures=signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P2UMmbNW4YQV"
      },
      "source": [
        "## Run the pipeline\n",
        "Use the following code to run the pipeline.\n",
        "\n",
        "* `FormatOutput` demonstrates how to extract values from the output protos.\n",
        "* `CreateModelHandler` demonstrates the model handler that needs to be passed into the Apache Beam RunInference API."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 193
        },
        "id": "PzjmXM_KvqHY",
        "outputId": "0aa60bef-52a0-4ce2-d228-3fac977d59e0"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.\n"
          ]
        },
        {
          "data": {
            "application/javascript": "\n        if (typeof window.interactive_beam_jquery == 'undefined') {\n          var jqueryScript = document.createElement('script');\n          jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n          jqueryScript.type = 'text/javascript';\n          jqueryScript.onload = function() {\n            var datatableScript = document.createElement('script');\n            datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n            datatableScript.type = 'text/javascript';\n            datatableScript.onload = function() {\n              window.interactive_beam_jquery = jQuery.noConflict(true);\n              window.interactive_beam_jquery(document).ready(function($){\n                \n              });\n            }\n            document.head.appendChild(datatableScript);\n          };\n          document.head.appendChild(jqueryScript);\n        } else {\n          window.interactive_beam_jquery(document).ready(function($){\n            \n          });\n        }"
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:From /usr/local/lib/python3.9/dist-packages/tfx_bsl/beam/run_inference.py:615: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "Use `tf.saved_model.load` instead.\n",
            "WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "example is 20.00 prediction is 104.36\n",
            "example is 40.00 prediction is 202.62\n",
            "example is 60.00 prediction is 300.87\n",
            "example is 90.00 prediction is 448.26\n"
          ]
        }
      ],
      "source": [
        "from tfx_bsl.public.beam.run_inference import CreateModelHandler\n",
        "\n",
        "class FormatOutput(beam.DoFn):\n",
        "    def process(self, element: prediction_log_pb2.PredictionLog):\n",
        "        predict_log = element.predict_log\n",
        "        input_value = tf.train.Example.FromString(predict_log.request.inputs['examples'].string_val[0])\n",
        "        input_float_value = input_value.features.feature['x'].float_list.value[0]\n",
        "        output_value = predict_log.response.outputs\n",
        "        output_float_value = output_value['output_0'].float_val[0]\n",
        "        yield (f\"example is {input_float_value:.2f} prediction is {output_float_value:.2f}\")\n",
        "\n",
        "tfexample_beam_record = tfx_bsl.public.tfxio.TFExampleRecord(file_pattern=predict_values_five_times_table)\n",
        "saved_model_spec = model_spec_pb2.SavedModelSpec(model_path=save_model_dir_multiply)\n",
        "inference_spec_type = model_spec_pb2.InferenceSpecType(saved_model_spec=saved_model_spec)\n",
        "model_handler = CreateModelHandler(inference_spec_type)\n",
        "with beam.Pipeline() as p:\n",
        "    _ = (p | tfexample_beam_record.RawRecordBeamSource()\n",
        "           | RunInference(model_handler)\n",
        "           | beam.ParDo(FormatOutput())\n",
        "           | beam.Map(print)\n",
        "        )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IXikjkGdHm9n"
      },
      "source": [
        "## Use `KeyedModelHandler` with `tfx-bsl`\n",
        "\n",
        "By default, the `ModelHandler` does not expect a key.\n",
        "\n",
        "* If you know that keys are associated with your examples, use `beam.KeyedModelHandler` to wrap the model handler.\n",
        "* If you don't know whether keys are associated with your examples, use `beam.MaybeKeyedModelHandler`.\n",
        "\n",
        "In addition to demonstrating how to use a keyed model handler, this step demonstrates how to use `tfx-bsl` examples."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "KPtE3fmdJQry",
        "outputId": "c33558fc-fb12-4c20-b828-b5520721f279"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "key 5.0 : example is 5.00 prediction is 30.67\n",
            "key 50.0 : example is 50.00 prediction is 251.75\n",
            "key 40.0 : example is 40.00 prediction is 202.62\n",
            "key 100.0 : example is 100.00 prediction is 497.38\n"
          ]
        }
      ],
      "source": [
        "from apache_beam.ml.inference.base import KeyedModelHandler\n",
        "from google.protobuf import text_format\n",
        "import tensorflow as tf\n",
        "\n",
        "class FormatOutputKeyed(FormatOutput):\n",
        "  # To simplify, inherit from FormatOutput.\n",
        "  def process(self, tuple_in: Tuple):\n",
        "    key, element = tuple_in\n",
        "    output = super().process(element)\n",
        "    yield ' : '.join([key, next(output)])\n",
        "\n",
        "def make_example(num):\n",
        "  # Return keyed values in the form of (key num, example).\n",
        "  key = f'key {num}'\n",
        "  tf_proto = text_format.Parse(\n",
        "    \"\"\"\n",
        "    features {\n",
        "      feature {key: \"x\" value { float_list { value: %f }}}\n",
        "    }\n",
        "    \"\"\"% num, tf.train.Example())\n",
        "  return (key, tf_proto)\n",
        "\n",
        "# Make a list of examples of type tf.train.Example.\n",
        "examples = [\n",
        "    make_example(5.0),\n",
        "    make_example(50.0),\n",
        "    make_example(40.0),\n",
        "    make_example(100.0)\n",
        "]\n",
        "\n",
        "tfexample_beam_record = tfx_bsl.public.tfxio.TFExampleRecord(file_pattern=predict_values_five_times_table)\n",
        "saved_model_spec = model_spec_pb2.SavedModelSpec(model_path=save_model_dir_multiply)\n",
        "inference_spec_type = model_spec_pb2.InferenceSpecType(saved_model_spec=saved_model_spec)\n",
        "keyed_model_handler = KeyedModelHandler(CreateModelHandler(inference_spec_type))\n",
        "with beam.Pipeline() as p:\n",
        "    _ = (p | 'CreateExamples' >> beam.Create(examples)\n",
        "           | RunInference(keyed_model_handler)\n",
        "           | beam.ParDo(FormatOutputKeyed())\n",
        "           | beam.Map(print)\n",
        "        )"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "X80jy3FqHjK4",
        "40qtP6zJuMXm",
        "YzvZWEv-1oiK",
        "rIwD_qEpX7Gu",
        "O_a0-4Gb19cy",
        "G-sAu3cf31f3",
        "r4dpR6dQ4JwX",
        "P2UMmbNW4YQV"
      ],
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
